import json
import os

import matplotlib.pyplot as plt
import numpy as np
import argparse
from coco import COCOLoader
from datasets import Dataset
from json_utils import convert_to_json_serializable
from PIL import Image
from tqdm import tqdm
import base64
from openai import OpenAI
import anthropic
import requests


def answer_cleaning(text):
    """Normalize an answer string before comparison.

    Surrounding whitespace is stripped first, the result is lowercased, and
    finally every hyphen is replaced by a single space.

    Args:
        text: The raw string to normalize.

    Returns:
        str: The normalized string.
    """
    trimmed = text.strip()
    lowered = trimmed.lower()
    return lowered.replace("-", " ")


def evaluate_correctness(generated_text, answers):
    """
    Evaluate if the generated text contains or is contained by any of the possible answers.

    Both sides are normalized with ``answer_cleaning`` before the substring
    comparison. Blank strings never count as a match: an empty string is a
    substring of every string, so without this guard an empty model output
    (or an empty reference answer) would always be scored as correct.

    Args:
        generated_text: The text generated by the model
        answers: List of possible correct answers. A single string is also
            accepted and treated as a one-element list (iterating it directly
            would compare against its individual characters).

    Returns:
        bool: True if the generated text matches any of the answers, False otherwise
    """
    cleaned_text = answer_cleaning(generated_text)
    if not cleaned_text.strip():
        # Empty / whitespace-only output cannot be a correct answer.
        return False

    if isinstance(answers, str):
        answers = [answers]

    for answer in answers:
        cleaned_answer = answer_cleaning(answer)
        if not cleaned_answer.strip():
            # A blank reference answer would trivially match everything.
            continue
        if cleaned_answer in cleaned_text or cleaned_text in cleaned_answer:
            return True

    return False


def print_correctness_distribution_table(evaluated_results):
    """
    Print a table showing the distribution of correctness cases across question types.

    Each row is one (first_hop, second_hop, full) correctness combination and
    each column is one question type (sorted), followed by a row total. A
    final row gives the per-question-type totals and the grand total.

    Args:
        evaluated_results: List of dictionaries containing evaluation results with keys
                          'first_hop_correct', 'second_hop_correct', 'full_correct', and 'question_type'

    Returns:
        tuple: ``(counts, col_totals, total_sum)`` where ``counts`` maps each
        correctness combination to a ``{question_type: count}`` dict,
        ``col_totals`` maps each question type to its total count, and
        ``total_sum`` is the number of counted results.
    """
    print("\n=== Correctness Distribution Across Question Types ===")

    question_types = set(data["question_type"] for data in evaluated_results)
    # Sort once; the same column order is reused by every printed line.
    ordered_types = sorted(question_types)

    # All eight (first_hop, second_hop, full) combinations, from
    # (True, True, True) down to (False, False, False).
    correctness_cases = [
        (first, second, full)
        for first in (True, False)
        for second in (True, False)
        for full in (True, False)
    ]

    # Create a dictionary to store counts for each combination
    counts = {
        case: {qtype: 0 for qtype in question_types} for case in correctness_cases
    }

    # Count occurrences of each combination
    for data in evaluated_results:
        case = (
            data["first_hop_correct"],
            data["second_hop_correct"],
            data["full_correct"],
        )
        counts[case][data["question_type"]] += 1

    # Question-type columns are at least 7 wide (the historical width) and
    # grow to fit the longest name so header, rows and separators stay aligned.
    label_width = 7
    type_width = max(
        [label_width] + [len(str(qtype)) for qtype in ordered_types]
    )

    separator = (
        "-+-".join(["-" * label_width] * 3)
        + "-+-"
        + "-+-".join("-" * type_width for _ in ordered_types)
        + "-+-"
        + "-" * label_width
    )

    # Print the table header
    print("\nCorrectness Distribution Table:")
    print(
        f"{'First':^7} | {'Second':^7} | {'Full':^7} | "
        + " | ".join(f"{str(qtype):^{type_width}}" for qtype in ordered_types)
        + " | Total"
    )
    print(separator)

    # Print each row of the table
    for case in correctness_cases:
        first, second, full = case
        row_total = sum(counts[case].values())
        print(
            f"{str(first):^7} | {str(second):^7} | {str(full):^7} | "
            + " | ".join(
                f"{counts[case][qtype]:^{type_width}}" for qtype in ordered_types
            )
            + f" | {row_total:^5}"
        )

    # Print column totals
    col_totals = {
        qtype: sum(counts[case][qtype] for case in correctness_cases)
        for qtype in question_types
    }
    total_sum = sum(col_totals.values())
    print(separator)
    print(
        f"{'Total':^7} | {' ':^7} | {' ':^7} | "
        + " | ".join(f"{col_totals[qtype]:^{type_width}}" for qtype in ordered_types)
        + f" | {total_sum:^5}"
    )

    return counts, col_totals, total_sum


def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as base64 text.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        str: The base64 encoding of the file's bytes, decoded as UTF-8.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


class APILLM:
    """Send a text prompt, optionally with one image, to a hosted chat model.

    The model ID selects the backend: IDs in ``OPENAI_MODELS`` are sent to the
    OpenAI chat-completions HTTP endpoint via ``requests``; IDs in
    ``ANTHROPIC_MODELS`` are sent through the ``anthropic`` client. Any other
    ID can be constructed but ``generate`` raises ``ValueError`` for it.

    Attributes set by ``__init__``:
        model_id: The model ID passed in.
        api_key: Value of ``OPENAI_API_KEY`` / ``ANTHROPIC_API_KEY`` for the
            matching backend, or ``None`` (unset variable or unknown model).
        client: ``anthropic.Anthropic`` instance for Anthropic models,
            otherwise ``None``.
    """

    # Single source of truth for backend routing (used by __init__ and generate).
    OPENAI_MODELS = ("gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14")
    ANTHROPIC_MODELS = ("claude-sonnet-4-20250514", "claude-3-5-haiku-20241022")

    OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"
    # Seconds to wait for the OpenAI HTTP call unless ``timeout=`` is passed.
    DEFAULT_TIMEOUT = 120

    def __init__(self, model_id):
        """Store the model ID and set up credentials for its backend.

        Args:
            model_id: Model identifier; see ``OPENAI_MODELS`` and
                ``ANTHROPIC_MODELS`` for the IDs that ``generate`` supports.
        """
        self.model_id = model_id
        # Always define these so every instance has the same attributes,
        # even when the model ID is not recognized.
        self.api_key = None
        self.client = None
        if model_id in self.OPENAI_MODELS:
            self.api_key = os.getenv("OPENAI_API_KEY")
        elif model_id in self.ANTHROPIC_MODELS:
            self.api_key = os.getenv("ANTHROPIC_API_KEY")
            self.client = anthropic.Anthropic()

    def generate(self, text, image_path, **kwargs):
        """Generate a reply with the backend that matches ``self.model_id``.

        Args:
            text: Prompt text.
            image_path: Path of an image to attach, or a falsy value for a
                text-only prompt.
            **kwargs: Forwarded to the backend method (e.g. ``temperature``).

        Returns:
            The text of the model's reply.

        Raises:
            ValueError: If ``self.model_id`` is not a supported model ID.
        """
        if self.model_id in self.OPENAI_MODELS:
            return self.generate_oai(text, image_path, **kwargs)
        elif self.model_id in self.ANTHROPIC_MODELS:
            return self.generate_anthropic(text, image_path, **kwargs)
        else:
            raise ValueError(f"Unknown model ID: {self.model_id}")

    def generate_oai(self, text, image_path, **kwargs):
        """Call the OpenAI chat-completions endpoint and return the first choice.

        Args:
            text: Prompt text.
            image_path: Path of an image to attach as a base64 JPEG data URL,
                or a falsy value for a text-only prompt.
            **kwargs: ``temperature`` (default 0) and ``timeout`` in seconds
                (default ``DEFAULT_TIMEOUT``).

        Returns:
            The ``content`` of the first returned choice's message.

        Raises:
            RuntimeError: If no API key was found in ``OPENAI_API_KEY``.
            requests.HTTPError: If the API answers with an error status.
        """
        if not self.api_key:
            # Fail with a clear message instead of sending "Bearer None".
            raise RuntimeError(
                "OPENAI_API_KEY is not set; cannot call the OpenAI API"
            )

        temperature = kwargs.get("temperature", 0)
        timeout = kwargs.get("timeout", self.DEFAULT_TIMEOUT)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        message_content = [{"type": "text", "text": text}]
        if image_path:
            base64_image = encode_image(image_path)
            message_content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    },
                }
            )

        payload = {
            "model": self.model_id,
            "messages": [
                {
                    "role": "user",
                    "content": message_content,
                }
            ],
            "max_tokens": 300,
            "temperature": temperature,
        }

        response = requests.post(
            self.OPENAI_CHAT_URL, headers=headers, json=payload, timeout=timeout
        )
        # Surface API errors (auth, rate limit, ...) as HTTPError rather than
        # as a KeyError on the missing "choices" field of an error body.
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def generate_anthropic(self, text, image_path, **kwargs):
        """Call the Anthropic messages API and return the first content block.

        Args:
            text: Prompt text.
            image_path: Path of an image to attach as base64 JPEG data, or a
                falsy value for a text-only prompt.
            **kwargs: ``temperature`` (default 0).

        Returns:
            The ``text`` of the first content block in the reply.
        """
        temperature = kwargs.get("temperature", 0)

        # The image block, when present, precedes the text block.
        message_content = []
        if image_path:
            base64_image = encode_image(image_path)
            message_content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": base64_image,
                    },
                }
            )
        message_content.append({"type": "text", "text": text})

        message = self.client.messages.create(
            model=self.model_id,
            max_tokens=300,
            messages=[
                {
                    "role": "user",
                    "content": message_content,
                }
            ],
            temperature=temperature,
        )
        return message.content[0].text



def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace: Holds ``model_id``, the string given via
        ``--model_id``.
    """
    cli = argparse.ArgumentParser(
        description="Parallel evaluation of generated data",
    )
    # NOTE(review): this default is not one of the model IDs that
    # APILLM.generate dispatches on, so a run without --model_id cannot
    # produce generations -- confirm the intended default.
    cli.add_argument(
        "--model_id",
        type=str,
        default="llava-hf/llava-1.5-7b-hf",
        help="Model ID",
    )
    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    # Script entry point: for every generated sample, query the model on the
    # first-hop question, the full (two-hop) question, the text-only second-hop
    # question and three prompt variants, score each reply, write the results
    # as JSONL and print the correctness distribution table.
    args = parse_args()
    model_id = args.model_id

    model = APILLM(model_id)

    generated_data_file = (
        # "output/object_identification/object_identification_val2014_image_input_3_gpt5_0_300.jsonl"
        "output/object_identification/object_identification_val2017_image_input_4_gpt5_0_5000.jsonl"
    )
    generated_data_name = generated_data_file.split("/")[-1]
    with open(generated_data_file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        generated_data = [json.loads(line) for line in f if line.strip()]
    coco = COCOLoader(
        dataType="val2017" if "val2017" in generated_data_file else "val2014"
    )
    evaluate_result_dir = f'output/object_identification/evaluate_result_{model_id.split("/")[-1]}_{generated_data_name.split(".")[0]}'
    os.makedirs(evaluate_result_dir, exist_ok=True)

    evaluated_results = []
    skipped_count = 0
    # hidden_states_dataset = []
    for data_idx, data in enumerate(tqdm(generated_data)):
        img_id = data["img_id"]
        first_hop_question = data["first_hop_question"]
        first_hop_answer = data["first_hop_answer"]
        second_hop_question_template = data["second_hop_question_template"]
        full_question = data["full_question"]
        second_hop_answer = data["second_hop_answer"]

        try:
            image_path = coco.get_img_path(img_id)

            first_hop_generated_text = model.generate(
                text="Answer in 1 word: " + first_hop_question,
                image_path=image_path,
            )

            full_generated_text = model.generate(
                text="Answer in 1 word: " + full_question,
                image_path=image_path,
            )

            # Score the first-hop reply and the full-question reply.
            first_hop_correct = evaluate_correctness(
                first_hop_generated_text, first_hop_answer
            )
            full_correct = evaluate_correctness(full_generated_text, second_hop_answer)

            # Fill the second-hop template with the model's own first-hop
            # reply when it was correct, otherwise with the reference answer.
            if first_hop_correct:
                second_hop_question = second_hop_question_template.format(
                    first_hop_generated_text.strip().lower()
                )
            else:
                second_hop_question = second_hop_question_template.format(
                    first_hop_answer[0].strip().lower()
                )

            # Second hop alone: text only, no image.
            second_hop_generated_text = model.generate(
                text="Answer in 1 word: " + second_hop_question,
                image_path=None,
            )

            second_hop_correct = evaluate_correctness(
                second_hop_generated_text, second_hop_answer
            )

            # Variant 1: full question with the first-hop answer appended.
            variant_1_question = full_question + f" ({first_hop_answer[0].strip().lower()})"
            variant_1_generated_text = model.generate(
                text="Answer in 1 word: " + variant_1_question,
                image_path=image_path,
            )

            variant_1_correct = evaluate_correctness(
                variant_1_generated_text, second_hop_answer
            )

            # Variant 2: second-hop question, this time with the image.
            variant_2_generated_text = model.generate(
                text="Answer in 1 word: " + second_hop_question,
                image_path=image_path,
            )

            variant_2_correct = evaluate_correctness(
                variant_2_generated_text, second_hop_answer
            )

            # Variant 3: full question without the "Answer in 1 word" prefix.
            variant_3_generated_text = model.generate(
                text=full_question,
                image_path=image_path,
            )

            variant_3_correct = evaluate_correctness(
                variant_3_generated_text, second_hop_answer
            )
        except Exception as exc:
            # Best effort: one failing sample must not abort the run, but the
            # failure is reported instead of being dropped silently.
            # KeyboardInterrupt/SystemExit are deliberately not caught.
            skipped_count += 1
            tqdm.write(
                f"[skip] sample {data_idx} (img_id={img_id}): "
                f"{type(exc).__name__}: {exc}"
            )
            continue

        data["first_hop_generated_text"] = first_hop_generated_text
        data["second_hop_generated_text"] = second_hop_generated_text
        data["full_generated_text"] = full_generated_text
        data["variant_1_generated_text"] = variant_1_generated_text
        data["variant_2_generated_text"] = variant_2_generated_text
        data["variant_3_generated_text"] = variant_3_generated_text
        data["first_hop_correct"] = first_hop_correct
        data["second_hop_correct"] = second_hop_correct
        data["full_correct"] = full_correct
        data["variant_1_correct"] = variant_1_correct
        data["variant_2_correct"] = variant_2_correct
        data["variant_3_correct"] = variant_3_correct

        evaluated_results.append(data)

        # Intermediate checkpoints of everything evaluated so far.
        if data_idx in (100, 1000, 2000):
            checkpoint_file = os.path.join(
                evaluate_result_dir,
                generated_data_name.replace(".jsonl", f"_{data_idx}.jsonl"),
            )
            with open(checkpoint_file, "w", encoding="utf-8") as f:
                # Distinct loop name: do not rebind the outer ``data``.
                for record in evaluated_results:
                    f.write(json.dumps(record) + "\n")

    evaluate_result_file = os.path.join(evaluate_result_dir, generated_data_name)
    with open(evaluate_result_file, "w", encoding="utf-8") as f:
        for record in evaluated_results:
            f.write(json.dumps(record) + "\n")

    if skipped_count:
        print(
            f"Skipped {skipped_count} of {len(generated_data)} samples due to errors."
        )

    # hidden_states_dataset = Dataset.from_list(hidden_states_dataset)
    # hidden_states_dataset.save_to_disk(
    #     os.path.join(evaluate_result_dir, "hidden_states_dataset")
    # )

    with open(evaluate_result_file, "r", encoding="utf-8") as f:
        evaluated_results = [json.loads(line) for line in f]

    # Print the correctness distribution table
    print_correctness_distribution_table(evaluated_results)
